{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h8Xy0vuPg-gS"
      },
      "source": [
        "```\n",
        "Copyright 2023 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License.\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c_XzbBLngvoM"
      },
      "source": [
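        "# FSQ: Finite Scalar Quantization\n",
        "\n",
        "A minimal JAX/NumPy implementation of the `FSQ` quantizer, followed by short usage examples."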
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bmQq4IJKBSiO"
      },
      "outputs": [],
      "source": [
        "import itertools\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import numpy as np"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LK5u500Vad2P"
      },
      "outputs": [],
      "source": [
        "# Type aliases used in the `FSQ` method signatures below. Both are plain\n",
        "# `jax.Array`s; the distinct names only document intent.\n",
        "# Codeword: quantized values as returned by `FSQ.quantize`.\n",
        "Codeword = jax.Array\n",
        "# Indices: codebook indices as returned by `FSQ.codes_to_indexes`.\n",
        "Indices = jax.Array\n",
        "\n",
        "\n",
        "def round_ste(z):\n",
        "  \"\"\"Rounds `z` to the nearest integer with a straight-through gradient.\n",
        "\n",
        "  The forward value equals `jnp.round(z)`; the rounding residual is wrapped in\n",
        "  `stop_gradient`, so the gradient with respect to `z` is the identity.\n",
        "  \"\"\"\n",
        "  residual = jax.lax.stop_gradient(jnp.round(z) - z)\n",
        "  return z + residual\n",
        "\n",
        "\n",
        "class FSQ:\n",
        "  \"\"\"Finite scalar quantizer with an implicit, grid-shaped codebook.\n",
        "\n",
        "  Each of the `len(levels)` input dimensions is bounded, rounded to one of\n",
        "  `levels[i]` integer values and divided by `levels[i] // 2`. A codeword maps\n",
        "  to a single integer index through a mixed-radix encoding in which the first\n",
        "  dimension is the least-significant digit.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, levels: list[int], eps: float = 1e-3):\n",
        "    \"\"\"Creates the quantizer.\n",
        "\n",
        "    Args:\n",
        "      levels: Number of quantization levels for each dimension.\n",
        "      eps: Relative shrink applied in `bound`; the half-range of each\n",
        "        dimension is scaled by `1 - eps`.\n",
        "    \"\"\"\n",
        "    # Copy so that later mutation of the caller's list cannot desynchronize\n",
        "    # `_levels` from the arrays derived from it.\n",
        "    self._levels = list(levels)\n",
        "    self._eps = eps\n",
        "    self._levels_np = np.asarray(self._levels)\n",
        "    # Mixed-radix place values: basis[i] == prod(levels[:i]).\n",
        "    self._basis = np.concatenate(\n",
        "        ([1], np.cumprod(self._levels_np[:-1]))).astype(np.uint32)\n",
        "\n",
        "    # Decode every index once to materialize the codebook.\n",
        "    self._implicit_codebook = self.indexes_to_codes(\n",
        "        np.arange(self.codebook_size))\n",
        "\n",
        "  @property\n",
        "  def num_dimensions(self) -\u003e int:\n",
        "    \"\"\"Number of dimensions expected from inputs.\"\"\"\n",
        "    return len(self._levels)\n",
        "\n",
        "  @property\n",
        "  def codebook_size(self) -\u003e int:\n",
        "    \"\"\"Size of the codebook, i.e. the product of all levels.\"\"\"\n",
        "    # `np.prod` returns a NumPy scalar; convert to honor the `int` annotation.\n",
        "    return int(np.prod(self._levels))\n",
        "\n",
        "  @property\n",
        "  def codebook(self):\n",
        "    \"\"\"Returns the implicit codebook. Shape (prod(levels), num_dimensions).\"\"\"\n",
        "    return self._implicit_codebook\n",
        "\n",
        "  def bound(self, z: jax.Array) -\u003e jax.Array:\n",
        "    \"\"\"Bound `z`, an array of shape (..., d).\n",
        "\n",
        "    Applies a shifted `tanh` per dimension, scaled by\n",
        "    `(levels - 1) * (1 - eps) / 2`. Dimensions with an even number of levels\n",
        "    are additionally offset by 0.5.\n",
        "    \"\"\"\n",
        "    half_l = (self._levels_np - 1) * (1 - self._eps) / 2\n",
        "    # Only even level counts get the half-step offset.\n",
        "    offset = jnp.where(self._levels_np % 2 == 1, 0.0, 0.5)\n",
        "    shift = jnp.tan(offset / half_l)\n",
        "    return jnp.tanh(z + shift) * half_l - offset\n",
        "\n",
        "  def quantize(self, z: jax.Array) -\u003e Codeword:\n",
        "    \"\"\"Quantizes z, returns quantized zhat, same shape as z.\"\"\"\n",
        "    quantized = round_ste(self.bound(z))\n",
        "\n",
        "    # Renormalize to [-1, 1].\n",
        "    half_width = self._levels_np // 2\n",
        "    return quantized / half_width\n",
        "\n",
        "  def _scale_and_shift(self, zhat_normalized):\n",
        "    \"\"\"Maps normalized codes to per-dimension digits in [0, ..., L-1].\"\"\"\n",
        "    half_width = self._levels_np // 2\n",
        "    return (zhat_normalized * half_width) + half_width\n",
        "\n",
        "  def _scale_and_shift_inverse(self, zhat):\n",
        "    \"\"\"Maps per-dimension digits in [0, ..., L-1] back to normalized codes.\"\"\"\n",
        "    half_width = self._levels_np // 2\n",
        "    return (zhat - half_width) / half_width\n",
        "\n",
        "  def codes_to_indexes(self, zhat: Codeword) -\u003e Indices:\n",
        "    \"\"\"Converts a `code` to an index in the codebook.\n",
        "\n",
        "    Args:\n",
        "      zhat: Codes of shape (..., num_dimensions), as returned by `quantize`.\n",
        "\n",
        "    Returns:\n",
        "      uint32 indices of shape (...).\n",
        "    \"\"\"\n",
        "    assert zhat.shape[-1] == self.num_dimensions\n",
        "    # Snap each digit to the nearest integer first: the divide/multiply round\n",
        "    # trip can leave a digit marginally below its integer value, which the\n",
        "    # truncating cast below would otherwise turn into the wrong index.\n",
        "    digits = self._scale_and_shift(zhat).round()\n",
        "    return (digits * self._basis).sum(axis=-1).astype(jnp.uint32)\n",
        "\n",
        "  def indexes_to_codes(self, indices: Indices) -\u003e Codeword:\n",
        "    \"\"\"Inverse of `codes_to_indexes`.\n",
        "\n",
        "    Args:\n",
        "      indices: Integer indices of shape (...).\n",
        "\n",
        "    Returns:\n",
        "      Codes of shape (..., num_dimensions).\n",
        "    \"\"\"\n",
        "    indices = indices[..., jnp.newaxis]\n",
        "    # Extract each mixed-radix digit: (index // basis[i]) % levels[i].\n",
        "    codes_non_centered = np.mod(\n",
        "        np.floor_divide(indices, self._basis), self._levels_np\n",
        "    )\n",
        "    return self._scale_and_shift_inverse(codes_non_centered)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X0JkjUv6Rl0Z"
      },
      "source": [
        "# Example usage"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PkrLkcvbRnbN"
      },
      "outputs": [],
      "source": [
        "fsq = FSQ(levels=[3, 5, 4])\n",
        "\n",
        "z = np.asarray([0.25, 0.6, -7])\n",
        "zhat = fsq.quantize(z)\n",
        "print(f\"Quantized {z} -\u003e {zhat}\")\n",
        "\n",
        "# We can map to an index in the codebook.\n",
        "idx = fsq.codes_to_indexes(zhat)\n",
        "print(f\"Code {zhat} is the {idx}-th index.\")\n",
        "\n",
        "# Back to code. Print the decoded value (not the original `zhat`) so the\n",
        "# round trip is actually shown, and check that it matches.\n",
        "code_out = fsq.indexes_to_codes(idx)\n",
        "print(f\"Index {idx} mapped back to {code_out}.\")\n",
        "np.testing.assert_allclose(code_out, zhat)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DnQBFfftSmyy"
      },
      "source": [
        "# Quantizing a multi-dimensional bottleneck"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kkJxytDvSqqL"
      },
      "outputs": [],
      "source": [
        "fsq = FSQ(levels=[5, 4, 3])\n",
        "\n",
        "d = fsq.num_dimensions\n",
        "# Seed the generator so this cell is reproducible under Restart \u0026 Run All.\n",
        "rng = np.random.default_rng(0)\n",
        "z = rng.uniform(size=(3, 8, 8, d))\n",
        "zhat = fsq.quantize(z)\n",
        "assert zhat.shape == (3, 8, 8, d)\n",
        "\n",
        "# Indices drop the trailing code dimension.\n",
        "indices = fsq.codes_to_indexes(zhat)\n",
        "assert indices.shape == (3, 8, 8)\n",
        "\n",
        "# Decoding the indices must reproduce the quantized codes.\n",
        "zhat_out = fsq.indexes_to_codes(indices)\n",
        "assert zhat_out.shape == zhat.shape\n",
        "\n",
        "np.testing.assert_allclose(zhat, zhat_out)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i5BuPjRsTbjK"
      },
      "source": [
        "# Validating codebook"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "epeDl8MmRqSj"
      },
      "outputs": [],
      "source": [
        "fsq = FSQ(levels=[3, 4])\n",
        "# The implicit codebook has one row per index and one column per dimension.\n",
        "assert fsq.codebook.shape == (fsq.codebook_size, fsq.num_dimensions)\n",
        "print(fsq.codebook)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/grp/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1b_TvS9wHAx_fel0mQrvPT838XtvC71AA",
          "timestamp": 1695665937803
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
